Customer segmentation is a very useful concept in marketing. By identifying distinct customer traits, businesses can understand their customers on a deeper level, which allows more strategic marketing and advertising aimed at different groups of customers.
Using K-means clustering, an unsupervised machine learning technique, we can group similar customers and identify several types of customer profiles.
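To make the idea concrete, the sketch below implements the core loop of K-means (Lloyd's algorithm) in plain NumPy. It alternately assigns each point to its nearest centroid and moves each centroid to the mean of its assigned points, stopping once the centroids no longer move. This is a simplified illustration that assumes random initial centroids. It is not how scikit-learn's `KMeans` is implemented (that class adds k-means++ seeding and several restarts), and the function name `kmeans_lloyd` is hypothetical.

```python
import numpy as np

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Minimal K-means (Lloyd's algorithm) with random-point initialisation.

    Illustration only: scikit-learn's KMeans adds k-means++ seeding,
    several restarts (n_init) and other optimisations.
    """
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)

    def assign(centroids):
        # Squared Euclidean distance from every point to every centroid
        d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        return d.argmin(axis=1)

    # Initialise the centroids as k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        labels = assign(centroids)  # assignment step
        # Update step: each centroid moves to the mean of its points
        # (an empty cluster keeps its previous centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    labels = assign(centroids)
    # Within-cluster sum of squares, the quantity scikit-learn reports as inertia_
    wcss = float(((X - centroids[labels]) ** 2).sum())
    return labels, centroids, wcss

# Usage on six toy 2-D points forming two obvious groups
pts = np.array([[1, 1], [1.5, 2], [1, 1.5], [8, 8], [8.5, 9], [9, 8]])
labels, centroids, wcss = kmeans_lloyd(pts, k=2)
print(labels, wcss)
```

The label numbers themselves are arbitrary; only which points share a label matters.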
This dataset consists of hypothetical customer data from a shopping mall.
Data source: https://www.kaggle.com/datasets/vjchoudhary7/customer-segmentation-tutorial-in-python
This dataset has five columns:
1. CustomerID: Unique ID assigned to the customer
2. Gender: Gender of the customer
3. Age: Age of the customer
4. Annual Income (k$): Annual income of the customer, in thousand dollars
5. Spending Score (1-100): Score assigned by the mall based on customer behavior and spending nature, ranging from 1 to 100
In total, there are 200 customer records in this dataset.
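The column descriptions above translate directly into a few sanity checks that can be run right after loading the file. The helper below is a hypothetical sketch: the expected column names and the 1 to 100 spending-score range come from the list above, while restricting `Gender` to `Male`/`Female` is an assumption based on the sample rows.

```python
import pandas as pd

EXPECTED_COLUMNS = ["CustomerID", "Gender", "Age",
                    "Annual Income (k$)", "Spending Score (1-100)"]

def check_mall_schema(df: pd.DataFrame) -> list:
    """Hypothetical helper: return a list of problems (empty list = all checks passed)."""
    missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    if missing:
        return [f"missing columns: {missing}"]
    problems = []
    if not df["CustomerID"].is_unique:
        problems.append("CustomerID is not unique")
    # Range taken from the column description (1-100)
    if not df["Spending Score (1-100)"].between(1, 100).all():
        problems.append("Spending Score outside 1-100")
    if (df["Annual Income (k$)"] < 0).any():
        problems.append("negative Annual Income")
    # Assumption: only 'Male' and 'Female' appear, as in the sample rows
    unexpected = set(df["Gender"].dropna()) - {"Male", "Female"}
    if unexpected:
        problems.append(f"unexpected Gender values: {sorted(unexpected)}")
    return problems
```

For example, `check_mall_schema(pd.read_csv("Mall_Customers.csv"))` should return an empty list if the file matches the description above.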
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly as py
import plotly.graph_objs as go
from sklearn.cluster import KMeans
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv("Mall_Customers.csv")
#observe samples of data
df.head()
|  | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) |
|---|---|---|---|---|---|
| 0 | 1 | Male | 19 | 15 | 39 |
| 1 | 2 | Male | 21 | 15 | 81 |
| 2 | 3 | Female | 20 | 16 | 6 |
| 3 | 4 | Female | 23 | 16 | 77 |
| 4 | 5 | Female | 31 | 17 | 40 |
#observe data types
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 5 columns):
 #   Column                  Non-Null Count  Dtype
---  ------                  --------------  -----
 0   CustomerID              200 non-null    int64
 1   Gender                  200 non-null    object
 2   Age                     200 non-null    int64
 3   Annual Income (k$)      200 non-null    int64
 4   Spending Score (1-100)  200 non-null    int64
dtypes: int64(4), object(1)
memory usage: 7.9+ KB
#check for missing values
df.isnull().sum()
CustomerID                0
Gender                    0
Age                       0
Annual Income (k$)        0
Spending Score (1-100)    0
dtype: int64
There is no missing data.
#overview of feature statistics
df.describe()
|  | CustomerID | Age | Annual Income (k$) | Spending Score (1-100) |
|---|---|---|---|---|
| count | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
| mean | 100.500000 | 38.850000 | 60.560000 | 50.200000 |
| std | 57.879185 | 13.969007 | 26.264721 | 25.823522 |
| min | 1.000000 | 18.000000 | 15.000000 | 1.000000 |
| 25% | 50.750000 | 28.750000 | 41.500000 | 34.750000 |
| 50% | 100.500000 | 36.000000 | 61.500000 | 50.000000 |
| 75% | 150.250000 | 49.000000 | 78.000000 | 73.000000 |
| max | 200.000000 | 70.000000 | 137.000000 | 99.000000 |
plt.figure(1, figsize = (15, 6))
features = ['Age', 'Annual Income (k$)', 'Spending Score (1-100)']
pos = 1
for f in features:
    plt.subplot(1, 3, pos)
    plt.subplots_adjust(hspace = 0.2, wspace = 0.3)
    #histogram with a kernel density estimate (sns.distplot is deprecated)
    sns.histplot(df[f], kde = True, stat = "density")
    pos += 1
plt.show()
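Plotting a histogram of each feature shows its distribution. Together with the summary statistics above: age ranges from 18 to 70 (median 36), annual income from 15 to 137 k$ (median 61.5), and spending score from 1 to 99 (median 50).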
sns.countplot(x = "Gender", data = df)
plt.title("Number of male and female customers")
plt.show()
There are more female customers.
#correlation to identify relationship between features
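#numeric features only: excludes CustomerID, and Gender (corr() raises on text columns in pandas >= 2.0)
sns.heatmap(df[['Age', 'Annual Income (k$)', 'Spending Score (1-100)']].corr(), annot = True, cmap = plt.cm.Greens)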
plt.title("Heatmap of features correlation")
plt.show()
sns.pairplot(df, vars=['Age', 'Annual Income (k$)', 'Spending Score (1-100)'], hue = 'Gender', kind = 'reg')
plt.show()
The correlation heatmap shows that, overall, the features are only weakly correlated. However, the pairplot suggests stronger associations between the features when the customers are separated by gender.
For example, there is a moderate negative correlation between age and spending score.
The pairplot also suggests that female customers are older and have higher annual incomes and spending scores.
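The per-gender patterns read off the pairplot can be checked numerically by computing the feature means and the correlation matrix separately for each gender. This is a minimal sketch using pandas `groupby`; `per_gender_summary` is a hypothetical helper name.

```python
import pandas as pd

FEATURES = ["Age", "Annual Income (k$)", "Spending Score (1-100)"]

def per_gender_summary(df: pd.DataFrame):
    """Hypothetical helper: per-gender feature means and Pearson correlations."""
    grouped = df.groupby("Gender")[FEATURES]
    means = grouped.mean()
    # One correlation matrix per gender, stacked with a (Gender, feature) row index
    corrs = grouped.corr()
    return means, corrs
```

With the mall data loaded as `df`, `means, corrs = per_gender_summary(df)` gives one row of means per gender, and an entry such as `corrs.loc[("Female", "Age"), "Spending Score (1-100)"]` gives the within-gender correlation between age and spending score.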
sns.lmplot(x = 'Age', y = 'Annual Income (k$)', data = df, hue = "Gender")
plt.title('Scatterplot of age and annual income')
plt.show()
In general, customers aged between 30 and 60 have higher annual incomes.
sns.lmplot(x = 'Age', y = 'Spending Score (1-100)', data = df, hue = "Gender")
plt.title('Scatterplot of age and spending score')
plt.show()
Young customers tend to have high spending scores, while customers aged 30 or above have low or moderate spending scores.
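The age pattern can also be summarised as a table by cutting `Age` into bands and aggregating the spending score within each band. The band edges below are an arbitrary illustrative choice (not from the original analysis) that covers the observed age range of 18 to 70, and `spending_by_age_band` is a hypothetical helper.

```python
import pandas as pd

def spending_by_age_band(df: pd.DataFrame, bins=(17, 29, 39, 49, 59, 70)):
    """Hypothetical helper: count, median and mean spending score per age band."""
    # pd.cut builds right-closed intervals by default, e.g. (17, 29] covers ages 18-29
    bands = pd.cut(df["Age"], bins=list(bins))
    # observed=True drops bands that contain no customers
    return (df.groupby(bands, observed=True)["Spending Score (1-100)"]
              .agg(["count", "median", "mean"]))
```

Calling `spending_by_age_band(df)` on the mall data returns one row per age band, which makes it easy to see whether median spending drops after the first band.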
sns.lmplot(x = 'Annual Income (k$)', y = 'Spending Score (1-100)', data = df, hue = "Gender")
plt.title('Scatterplot of annual income and spending score')
plt.show()
We can observe 5 groups of customers in this plot.
We will group customers using three features: annual income, spending score and age.
To determine the appropriate number of clusters, we will be using the elbow method.
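Reading the elbow off the plot is a visual judgment. One common heuristic for automating it, similar in spirit to the "Kneedle" method, is to rescale both axes and pick the k whose point lies farthest from the straight line joining the first and last points of the curve. This is a sketch of that heuristic, not the method used in this notebook, and `elbow_by_max_distance` is a hypothetical name.

```python
import numpy as np

def elbow_by_max_distance(ks, wcss):
    """Hypothetical helper: return the k whose (k, WCSS) point is farthest from
    the straight line through the first and last points of the elbow curve."""
    ks = np.asarray(list(ks), dtype=float)
    w = np.asarray(wcss, dtype=float)
    # Rescale both axes to [0, 1] so WCSS units do not dominate the distance
    x = (ks - ks.min()) / (ks.max() - ks.min())
    y = (w - w.min()) / (w.max() - w.min())
    # Unit vector along the chord from the first point to the last point
    p1 = np.array([x[0], y[0]])
    p2 = np.array([x[-1], y[-1]])
    direction = (p2 - p1) / np.linalg.norm(p2 - p1)
    # Perpendicular distance = |2-D cross product| of (point - p1) with the unit chord
    rel = np.column_stack([x, y]) - p1
    dist = np.abs(rel[:, 0] * direction[1] - rel[:, 1] * direction[0])
    return int(ks[np.argmax(dist)])
```

Once the `wcss` list below has been computed, `elbow_by_max_distance(range(1, 11), wcss)` gives a second opinion; it can disagree with a visual reading when the curve bends gradually.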
#features for training
X = df[['Age', 'Annual Income (k$)', 'Spending Score (1-100)']]
#Calculate the Within-Cluster Sum of Square (WCSS)
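wcss = []
for i in range(1, 11):
    #n_init = 10 fixes the number of restarts (the default changed to 'auto' in scikit-learn 1.4)
    kmeans = KMeans(n_clusters = i, init = 'k-means++', n_init = 10, random_state = 5)
    kmeans.fit(X)
    wcss.append(kmeans.inertia_)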
#plot elbow curve
plt.plot(range(1, 11), wcss, marker ="8")
plt.xlabel('Number of clusters')
plt.xticks(np.arange(1,11,1)) #to set x axis grid
plt.ylabel('WCSS')
plt.title('Elbow curve')
plt.show()
Based on the graph, WCSS decreases sharply for small k and the curve bends into an elbow at k = 5, after which additional clusters reduce WCSS only slightly. We therefore use 5 clusters for the model.
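As an optional cross-check on k (not part of the original analysis), the mean silhouette coefficient from `sklearn.metrics.silhouette_score` can be computed for each candidate k. Higher values indicate tighter, better-separated clusters. The silhouette is undefined for a single cluster, so the range starts at 2. `silhouette_by_k` is a hypothetical helper.

```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def silhouette_by_k(X, k_values=range(2, 11), random_state=5):
    """Hypothetical helper: mean silhouette coefficient for each candidate k."""
    scores = {}
    for k in k_values:
        labels = KMeans(n_clusters=k, init="k-means++", n_init=10,
                        random_state=random_state).fit_predict(X)
        scores[k] = silhouette_score(X, labels)
    return scores
```

With the training features `X`, `scores = silhouette_by_k(X)` followed by `max(scores, key=scores.get)` gives the k with the highest silhouette. The elbow and the silhouette do not always agree, so treat the result as supporting evidence rather than a rule.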
#define model with 5 clusters
km = KMeans(n_clusters = 5, init = "k-means++", n_init = 10, random_state = 5)
#fit input data to train model, and predict labels
y = km.fit_predict(X)
#add labels to dataframe
df['Label'] = y
df.head()
|  | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) | Label |
|---|---|---|---|---|---|---|
| 0 | 1 | Male | 19 | 15 | 39 | 2 |
| 1 | 2 | Male | 21 | 15 | 81 | 1 |
| 2 | 3 | Female | 20 | 16 | 6 | 2 |
| 3 | 4 | Female | 23 | 16 | 77 | 1 |
| 4 | 5 | Female | 31 | 17 | 40 | 2 |
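Once fitted, the model can place a new customer into one of the five segments with `KMeans.predict`, which returns the label of the nearest centroid; the centroids themselves are stored in `cluster_centers_`. The helper names below are hypothetical, and the feature order assumes the same three columns used to build `X`.

```python
import pandas as pd
from sklearn.cluster import KMeans

FEATURES = ["Age", "Annual Income (k$)", "Spending Score (1-100)"]

def centres_table(model: KMeans) -> pd.DataFrame:
    """Hypothetical helper: cluster centroids as a DataFrame, one row per label."""
    return pd.DataFrame(model.cluster_centers_, columns=FEATURES).rename_axis("Label")

def assign_segment(model: KMeans, age, income_k, spending_score) -> int:
    """Hypothetical helper: label of the centroid nearest to one new customer."""
    # Same column names and order as the training data, so predict() sees matching features
    new = pd.DataFrame([[age, income_k, spending_score]], columns=FEATURES)
    return int(model.predict(new)[0])
```

For example, `assign_segment(km, 25, 40, 70)` would return the segment for a hypothetical 25-year-old customer with an annual income of 40 k$ and a spending score of 70.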
#Create a 3D plot to view the clusters determined by the model
trace1 = go.Scatter3d(
    x = df['Spending Score (1-100)'],
    y = df['Annual Income (k$)'],
    z = df['Age'],
    mode = 'markers',
    marker = dict(
        color = df['Label'],
        size = 5,
        line = dict(
            color = df['Label'],
        ),
        opacity = 0.9
    )
)
layout = go.Layout(
    title = 'Clusters',
    scene = dict(
        xaxis = dict(title = 'Spending Score (1-100)'),
        yaxis = dict(title = 'Annual Income (k$)'),
        zaxis = dict(title = 'Age')
    )
)
fig = go.Figure(data=trace1, layout=layout)
# py.offline.iplot(fig)
fig.show("notebook")
1. Blue cluster
2. Orange cluster
3. Yellow cluster
4. Purple cluster
5. Pink cluster
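The list above names the clusters only by colour. A per-label summary (cluster size, mean age, income and spending score, and share of female customers) is one way to describe each segment. This sketch assumes the `Label` column added earlier. Which plotted colour corresponds to which label depends on Plotly's default colour scale, so match them by inspecting the plot. `cluster_profiles` is a hypothetical helper.

```python
import pandas as pd

FEATURES = ["Age", "Annual Income (k$)", "Spending Score (1-100)"]

def cluster_profiles(df: pd.DataFrame, label_col: str = "Label") -> pd.DataFrame:
    """Hypothetical helper: size, feature means and female share per cluster."""
    grouped = df.groupby(label_col)
    profile = grouped[FEATURES].mean().round(1)
    # Number of customers in each cluster, as the first column
    profile.insert(0, "Size", grouped.size())
    # Fraction of customers in each cluster whose Gender is 'Female'
    profile["Female share"] = grouped["Gender"].apply(lambda g: (g == "Female").mean()).round(2)
    # Order clusters from lowest to highest mean income for easier reading
    return profile.sort_values("Annual Income (k$)")
```

Running `cluster_profiles(df)` after the labels have been added gives one row per cluster, which can then be used to write a short description of each coloured group.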